import matplotlib.pyplot as plt
import numpy as np
from scipy import *

# --- Optical constants --------------------------------------------------------
# Excitation frequency, entered as a wavenumber in cm^-1 and converted below to a
# wavelength in micrometres (k0 treats `lambd` as micrometres).
lambd = 930

# Lorentz-oscillator permittivity of the SiC substrate evaluated at `lambd`.
# `to` and `gamma` presumably are the TO-phonon frequency and damping in cm^-1
# (TODO confirm).
to = 788
gamma = 6.8
sic = 6.49 + 3.23/(1 - (lambd/to)**2 - 1j*gamma*lambd/to**2)
lambd = 1e4/lambd              # cm^-1 -> micrometres
k0 = 2*np.pi/(lambd*1e-6)      # free-space wavenumber in 1/m
ge = 16                        # permittivity of the Ge layer
geo2 = 1.1**2                  # permittivity of the GeO2 layer (index 1.1)
air = 1

# --- Sweep grid ---------------------------------------------------------------
num_d = 200                    # number of GeO2 thicknesses
nbeta = 201                    # number of in-plane wavevectors
htotal = 7e-9                  # GeO2 + Ge thickness, held fixed (metres)
d_geo2 = np.linspace(0e-9, htotal, num_d)
d_ge = htotal - d_geo2

# In-plane wavevectors for polariton wavelengths from 50 nm to 2000 nm.
beta = 2*np.pi/np.linspace(50e-9, 2000e-9, nbeta)

# Out-of-plane wavevectors in each medium.  For this grid beta**2 exceeds
# k0**2*eps for the real permittivities, so the arguments are negative;
# np.emath.sqrt returns the complex (evanescent) root where np.sqrt gives NaN.
k_ge = np.emath.sqrt(k0**2*ge - beta**2)
k_geo2 = np.emath.sqrt(k0**2*geo2 - beta**2)
k_sic = np.emath.sqrt(k0**2*sic - beta**2)
k_air = np.emath.sqrt(k0**2*air - beta**2)

# Interface reflection coefficients (eps_b*k_a - eps_a*k_b form) for
# air|GeO2, GeO2|Ge and Ge|SiC, one per beta.
r12 = (geo2*k_air - air*k_geo2)/(geo2*k_air + air*k_geo2)
r23 = (ge*k_geo2 - geo2*k_ge)/(ge*k_geo2 + geo2*k_ge)
r34 = (sic*k_ge - ge*k_sic)/(sic*k_ge + ge*k_sic)


def _interface_matrix(r):
    """Return a stack of 2x2 matrices [[1, r], [r, 1]], one per element of *r*."""
    m = np.ones(np.shape(r) + (2, 2), dtype=complex)
    m[..., 0, 1] = r
    m[..., 1, 0] = r
    return m


def _propagation_matrix(phi):
    """Return a stack of 2x2 matrices diag(exp(-i*phi), exp(i*phi)), one per element of *phi*."""
    m = np.zeros(np.shape(phi) + (2, 2), dtype=complex)
    m[..., 0, 0] = np.exp(-1j*phi)
    m[..., 1, 1] = np.exp(1j*phi)
    return m


# Phase thickness of the GeO2 and Ge layers for every (thickness, beta) pair,
# shape (num_d, nbeta).
phi2 = np.outer(d_geo2, k_geo2)
phi3 = np.outer(d_ge, k_ge)

# Transfer matrix of the stack for the whole grid at once: the (nbeta, 2, 2)
# interface stacks broadcast against the (num_d, nbeta, 2, 2) propagation
# stacks under `@`, giving M with shape (num_d, nbeta, 2, 2).
M = (_interface_matrix(r12) @ _propagation_matrix(phi2) @ _interface_matrix(r23)
     @ _propagation_matrix(phi3) @ _interface_matrix(r34))

# Imaginary part of the reflection coefficient, shape (num_d, nbeta).
R = (M[..., 1, 0]/M[..., 0, 0]).imag

# Column (beta) index of the Im(R) maximum for each GeO2 thickness; an integer
# array so it can index `beta` directly.
indt = np.argmax(R, axis=1)

from scipy.optimize import root

# Short aliases used by the analytical four-layer dispersion relation.
h, d = d_geo2, d_ge            # GeO2 and Ge thickness at each sweep point
eps_h, eps_d = geo2, ge        # GeO2 and Ge permittivities
eps3, eps2 = 1, sic            # air and SiC permittivities
beta_analytical = zeros(len(h), dtype='cfloat')   # one complex root per thickness
beta_approx = 2*pi/500e-9      # starting guess: a 500 nm polariton wavelength

def make_dispersion(h, d):
    """Build the residual of the four-layer dispersion relation for one geometry.

    Parameters
    ----------
    h, d : float
        Thicknesses substituted into the relation (the solver loop in this
        script passes the GeO2 and Ge thickness of one sweep point).

    Returns
    -------
    callable
        ``f([re, im]) -> [Re(res), Im(res)]`` where ``res`` is the residual
        evaluated at the complex ``beta = re + 1j*im``.  Splitting into real
        and imaginary parts lets a real-vector solver such as
        ``scipy.optimize.root`` search for a complex root.

    The permittivities ``eps_h``, ``eps_d``, ``eps2`` and ``eps3`` are read
    from module globals each time the returned function is called.
    """
    def four_layer_dispersion(beta):
        re_part, im_part = beta[0], beta[1]
        q = re_part + im_part*1j
        # The three exponential terms, summed in the same order as the relation.
        t1 = exp(2*q*(h-d))*(eps_h-eps_d)*(eps3-eps_h)*(eps2+eps_d)
        t2 = exp(2*q*h)*(eps_h+eps_d)*(eps3-eps_h)*(eps2-eps_d)
        t3 = exp(-2*q*d)*(eps_h+eps_d)*(eps3+eps_h)*(eps2+eps_d)
        constant = (eps_h-eps_d)*(eps3+eps_h)*(eps2-eps_d)
        residual = (t1 - t2 + t3) - constant
        return [residual.real, residual.imag]

    return four_layer_dispersion

import warnings

# Follow one branch of the analytical dispersion relation across the thickness
# sweep.  Each solve is warm-started from the most recent *converged* root, so a
# single failure does not seed every following solve with an unfilled entry.
last_root = beta_approx
failed = []
for i, (h_i, d_i) in enumerate(zip(h, d)):
    four_layer = make_dispersion(h_i, d_i)
    sol = root(four_layer, [last_root.real, last_root.imag])
    if not sol.success:
        # beta_analytical[i] is left untouched for this thickness.
        failed.append(i)
        continue
    beta_re, beta_im = sol.x
    root_c = beta_re + beta_im*1j
    # Store the root with a positive imaginary part (negate it otherwise).
    beta_analytical[i] = root_c if beta_im > 0 else -root_c
    last_root = beta_analytical[i]

if failed:
    warnings.warn('root() did not converge for {} thickness index(es): {}'.format(len(failed), failed))

# Earlier linear model of the polariton wavelength versus layer thickness,
# written in terms of the Ge thickness array `d` and the total thickness.
lambdp = ((2*eps_h*eps_d*(eps2 + eps3) + (eps2*eps3 - eps_h*eps_d)*(eps_h - eps_d))*d
          + eps_d*(eps_h - eps2)*(eps_h - eps3)*htotal) / (-eps_h*eps_d*(eps2 + eps3)/2/pi)

# new linear approximation 20190331
eps_s = sic
# eps3 is the air permittivity (1).
alpha1 = (eps_h - eps_d)*(eps3 - eps_h)*(eps_s + eps_d)
alpha2 = (eps_h + eps_d)*(eps3 - eps_h)*(eps_d - eps_s)
alpha3 = (eps_h + eps_d)*(eps3 + eps_h)*(eps_s + eps_d)
alpha4 = (eps_h - eps_d)*(eps3 + eps_h)*(eps_s - eps_d)

# Total stack thickness; tied to htotal so the sweep and the model cannot drift apart.
T = htotal
_lin_scale = 4*pi/(alpha4 - alpha1 - alpha2 - alpha3)   # shared by every formula below

# Each formula is linear in the GeO2 thickness d_geo2.
lambdp4 = ((alpha1 + alpha4)*d_geo2 - (alpha4 - alpha2)*T)*_lin_scale
lambdp2 = lambdp4   # identical formula to lambdp4
lambdp3 = ((2*alpha1 + alpha2 + alpha3)*d_geo2 - (alpha4 - alpha2)*T)*_lin_scale
lambdp5 = ((2*alpha4 - alpha2 - alpha3)*d_geo2 - (alpha4 - alpha2)*T)*_lin_scale
lambdp7 = ((2*alpha4 - alpha2 - alpha3)*d_geo2 - (alpha1 + alpha3)*T)*_lin_scale

# Probe wavevectors for 345/360/400 nm polariton wavelengths (presumably
# measured values on different days; confirm).  The stored roots are matched
# through ksp + Re(beta), i.e. they are expected to have negative real parts.
ksp_day1 = 2*pi/345e-9
ksp_day10 = 2*pi/360e-9
ksp_day16 = 2*pi/400e-9
idx1, idx10, idx16 = (argmin(abs(ksp + beta_analytical.real))
                      for ksp in (ksp_day1, ksp_day10, ksp_day16))

# Ratio Im(beta)/Re(beta) at a fixed index and at each probe.
for idx in (10, idx1, idx10, idx16):
    print(beta_analytical[idx].imag/beta_analytical[idx].real)



save = 0

# Wavenumber (cm^-1) of the simulated excitation, recovered from `lambd`
# (micrometres) so output file names match the frequency actually simulated.
lambd0 = int(round(1e4/lambd))


def _plot_model_curves(xticks=None):
    """Plot the analytical root, the Im(R) ridge and formulas 4/5/7 on the current axes.

    *xticks*, if given, is applied before the x-limits are set.
    """
    thickness_nm = d_geo2*1e9
    curves = [
        (-2*pi/real(beta_analytical)*1e9, 'r', 'analytical'),
        (2*pi/beta[indt]*1e9, 'k', 'max Imag(R)'),
        (-lambdp4*1e9, 'g', 'Formula 4'),
        (-lambdp5*1e9, 'm', 'Formula 5'),
        (-lambdp7*1e9, 'c', 'Formula 7'),
    ]
    for x, color, label in curves:
        plt.plot(x, thickness_nm, linewidth=2, color=color, label=label)
    plt.xlabel('Polariton Wavelength (nm)', fontsize=12, fontweight='bold')
    plt.ylabel('GeO2 thickness', fontsize=12, fontweight='bold')
    if xticks is not None:
        plt.gca().set_xticks(xticks)
    plt.xlim(2*pi/max(beta)*1e9, 2*pi/min(beta)*1e9)
    plt.ylim(0, 7)
    plt.legend(frameon=False)


def _plot_im_r(ax):
    """Draw Im(R) on *ax*; the coordinates are cell centres, hence shading='auto'."""
    ax.pcolor(2*pi/beta*1e9, d_geo2*1e9, R, shading='auto')


def _widen_current_figure(factor=1.5*0.9):
    """Stretch the current figure horizontally by *factor*, keeping its height."""
    fig_w, fig_h = plt.gcf().get_size_inches()
    plt.gcf().set_size_inches((fig_w*factor, fig_h))


if save:
    # Curves and the Im(R) map are written to separate transparent files.
    fig, ax = plt.subplots()
    _plot_model_curves(xticks=range(50, 850, 150))
    _widen_current_figure()
    plt.savefig('lines_{}.svg'.format(lambd0), dpi=300, transparent=True)

    fig, ax = plt.subplots()
    _plot_im_r(ax)
    _widen_current_figure()
    plt.axis('off')
    plt.savefig('contour_{}.png'.format(lambd0), dpi=300, transparent=True)
else:
    fig, ax = plt.subplots()
    _plot_model_curves()
    _plot_im_r(ax)
    _widen_current_figure()
    plt.tight_layout()
    plt.show()


# wn = 1e4/lambd
# d_geo2 = d_geo2.reshape((-1,1))
# base =r'data{}/'.format(int(wn))
# savetxt(base+'thickness_GeO_axis.txt', d_geo2)
# savetxt(base+'Imag_reflection_{}.txt'.format(int(wn)), R)
# savetxt(base+'analytical_model_{}.txt'.format(int(wn)), abs(2*pi/real(beta_analytical)))
# savetxt(base+'max_imag_R_{}.txt'.format(int(wn)), 2*pi/beta[indt])
# savetxt(base+'polariton_wavelength_axis.txt', 2*pi/beta.reshape(-1, 1))
